-
Notifications
You must be signed in to change notification settings - Fork 32
simplify JAX API #2998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
simplify JAX API #2998
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR refactors the import_petab_problem function to directly return a JAXProblem instance when jax=True, instead of returning a JAXModel that users must manually wrap. This simplifies the API by eliminating the need for users to explicitly create JAXProblem(jax_model, petab_problem).
Key changes:
- Modified
import_petab_problemto returnJAXProbleminstead ofJAXModelwhenjax=True - Updated all test files and examples to use the simplified API by removing manual
JAXProbleminstantiation - Removed unused
JAXProblemimport from test files where it's no longer needed
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| python/sdist/amici/petab/petab_import.py | Updated return type annotation and implementation to create and return JAXProblem directly when jax=True |
| tests/benchmark_models/test_petab_benchmark_jax.py | Removed manual JAXProblem instantiation and unused import |
| python/tests/test_jax.py | Updated multiple test functions to use the simplified API without manual JAXProblem creation |
| doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb | Updated tutorial to reflect new API, including documentation and variable references |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2998 +/- ##
==========================================
- Coverage 76.73% 76.05% -0.68%
==========================================
Files 309 309
Lines 19884 19884
Branches 1498 1498
==========================================
- Hits 15258 15123 -135
- Misses 4613 4748 +135
Partials 13 13
Flags with carried forward coverage won't be shown. Click here to find out more.
🚀 New features to boost your workflow:
|
|



remove user facing instantiations of JaxModel in favour of directly instantiating JaxProblem.